"""
Merge patent-to-patent similarity scores for patent years min_year through max_year (inclusive).
Robustness check: validation using a random sample of 5 patents.
2024-06-11
"""

import pandas as pd
import os
import numpy as np
import glob

# Root of the PatentsView data tree; every relative path used below
# (patentsberta/... inputs and the cleandata/... output) resolves against it.
DATA_ROOT = r'/Volumes/Zihao_SSD2/PatentsView/'
os.chdir(DATA_ROOT)

def load_patent_dict(path="patentsberta/patent_dict.csv"):
    """
    Load the patent index -> (patent_id, patent_year) lookup table.

    Args:
        path: CSV path (or file-like buffer) containing at least the columns
            'idx', 'patent_id' and 'patent_year', in any order.
    Returns:
        df ['patent_idx', 'patent_id', 'patent_year'] with duplicate
        [patent_idx, patent_id] rows removed (a warning is printed if any).
    """
    dfpatent = pd.read_csv(
        path,
        usecols=["idx", "patent_id", "patent_year"],
        # dtype keys must match the CSV header; int16 is plenty for years and
        # will raise if patent_year contains missing values.
        dtype={"idx": "int64", "patent_id": int, "patent_year": "int16"},
    )
    # read_csv keeps the file's column order regardless of the order given in
    # `usecols`, so rename by name and select explicitly instead of assigning
    # column labels positionally.
    dfpatent = dfpatent.rename(columns={"idx": "patent_idx"})[
        ["patent_idx", "patent_id", "patent_year"]
    ]
    if dfpatent.duplicated(["patent_idx", "patent_id"]).any():
        print("warning. duplicated [patent_idx, patent_id] found in patent_dict.csv!")
        dfpatent = dfpatent.drop_duplicates(["patent_idx", "patent_id"])
    return dfpatent


def _merge_sim_files(files, dfpatent, df_cited_patent):
    """
    Read one year's similarity CSVs and attach focal and cited patent id/year.

    Args:
        files: paths of CSVs with columns 'patent_idx', 'cited_patent_idx',
            'sim_score'.
        dfpatent: ['patent_idx', 'patent_id', 'patent_year'] lookup table.
        df_cited_patent: the same table with every column prefixed 'cited_'.
    Returns:
        The concatenated rows joined to both lookup tables.
    Raises:
        ValueError: if any row's index is missing from the lookup table.
        pandas.errors.MergeError: if the lookup table maps one index to
            several patents.
    """
    df_file = pd.concat(
        [
            pd.read_csv(
                file,
                usecols=["patent_idx", "cited_patent_idx", "sim_score"],
                dtype={
                    "patent_idx": int,
                    "cited_patent_idx": int,
                    "sim_score": "float32",
                },
            )
            for file in files
        ],
        ignore_index=True,
    )
    n_rows = len(df_file)
    # validate= catches an index that maps to several patents (which would
    # duplicate rows) with a clear error instead of a silent row blow-up.
    df_file = df_file.merge(dfpatent, on="patent_idx", validate="many_to_one").merge(
        df_cited_patent, on="cited_patent_idx", validate="many_to_one"
    )
    print(len(df_file))
    # Inner joins silently drop rows whose index is absent from the lookup
    # table; fail loudly (an assert would be stripped under `python -O`).
    if len(df_file) != n_rows:
        raise ValueError(
            f"{n_rows - len(df_file)} similarity rows reference patent indices "
            "missing from patent_dict.csv"
        )
    return df_file


def merge_years(min_year: int, max_year: int) -> pd.DataFrame:
    """
    Merge all years of processed cosine similarity DataFrames.

    For each year in [min_year, max_year] (inclusive), every CSV under
    patentsberta/patent_sim_combined/sim_{yr}_sample/ is read and joined to
    patent_dict.csv on both the focal (patent_idx) and the cited
    (cited_patent_idx) index. All years are then stacked and sorted.

    Args:
        min_year: min patent year to merge
        max_year: max patent year to merge (inclusive)
    Returns:
        df ['patent_id', 'patent_year', 'cited_patent_id', 'cited_patent_year',
            'sim_score', 'patent_idx', 'cited_patent_idx'], sorted by
        patent_idx ascending, sim_score descending, then cited_patent_idx
        descending. If no files are found, the frame is empty but has these
        columns.
    Raises:
        ValueError: if a similarity row references an index missing from
            patent_dict.csv.
        pandas.errors.MergeError: if patent_dict.csv maps one index to
            several patents.
    """
    out_cols = ['patent_id', 'patent_year', 'cited_patent_id', 'cited_patent_year',
                'sim_score', 'patent_idx', 'cited_patent_idx']

    print('Loading patent_dict.csv...')
    dfpatent = load_patent_dict()

    # Give the cited-side copy its own column names so the second merge does
    # not produce _x/_y suffixed columns that must be renamed afterwards.
    df_cited_patent = dfpatent.rename(columns={
        "patent_idx": "cited_patent_idx",
        "patent_id": "cited_patent_id",
        "patent_year": "cited_patent_year",
    })

    yr_list = []
    for yr in range(min_year, max_year + 1):
        # sorted() makes the read order deterministic across filesystems.
        files = sorted(glob.glob(f"patentsberta/patent_sim_combined/sim_{yr}_sample/*.csv"))
        print(f"start process and merge {len(files)} files for year {yr}")
        if not files:
            # pd.concat([]) would raise; skip the year instead of crashing.
            print(f"warning. no similarity files found for year {yr}, skipped")
            continue
        yr_list.append(_merge_sim_files(files, dfpatent, df_cited_patent))

    if not yr_list:
        return pd.DataFrame(columns=out_cols)

    df = pd.concat(yr_list, ignore_index=True)
    del yr_list  # release the per-year frames before sorting the combined one

    # A single sort of the combined frame; per-year pre-sorting is redundant.
    df.sort_values(["patent_idx", "sim_score", "cited_patent_idx"],
                   ascending=[True, False, False], inplace=True)
    return df[out_cols]

# Years to merge; also used to name the output file so the two cannot drift apart.
MIN_YEAR, MAX_YEAR = 1981, 2015

df_sim = merge_years(MIN_YEAR, MAX_YEAR)
print("nrows of df_sim: ", len(df_sim))
# DataFrame.to_csv does not create missing parent directories.
os.makedirs("cleandata", exist_ok=True)
df_sim.to_csv(f"cleandata/sim_score_{MIN_YEAR}_{MAX_YEAR}_sample.csv", index=False)